# -*- coding:utf-8 -*-
import backtrader as bt
#####################
import pandas as pd
import os
import datetime
import matplotlib.pyplot as plt

class MyStrategy(bt.Strategy):
    """Dual simple-moving-average crossover strategy.

    While flat, buys when the short SMA crosses above the long SMA; while
    holding a position, sells when the short SMA crosses below the long SMA.
    Every order submitted is printed to stdout with its bar date and close.

    Params:
        short_period: window of the fast SMA (default 5).
        long_period: window of the slow SMA (default 20).
    """

    params = (
        ('short_period', 5),
        ('long_period', 20),
    )

    def __init__(self):
        self.dataclose = self.datas[0].close
        # Fast SMA over `short_period` bars.
        self.sma_5 = bt.indicators.SimpleMovingAverage(
            self.dataclose, period=self.params.short_period)
        # Slow SMA over `long_period` bars. NOTE: despite the attribute name,
        # the window is `long_period` (20 by default), not 60; the name is
        # kept unchanged so any external references keep working.
        self.sma_60 = bt.indicators.SimpleMovingAverage(
            self.dataclose, period=self.params.long_period)

        # buy signal is > 0 on the bar where sma_5 crosses above sma_60;
        # sell signal is > 0 on the bar where sma_5 crosses below sma_60.
        self.buy_singal = bt.indicators.CrossOver(self.sma_5, self.sma_60)
        self.sell_singal = bt.indicators.CrossDown(self.sma_5, self.sma_60)
        # The signals are helpers only; keep them off the chart.
        self.buy_singal.plotinfo.plot = False
        self.sell_singal.plotinfo.plot = False

        # Last order submitted by next(); None until the first trade.
        self.order = None

    def start(self):
        """Hook called once before the backtest starts; intentionally a no-op."""
        pass

    def prenext(self):
        """Hook called on bars before all indicators are ready; a no-op."""
        pass

    def nextstart(self):
        """Handle the first bar on which every indicator has enough data.

        Backtrader's default ``nextstart`` forwards to ``next``. Overriding it
        with a bare ``pass`` would skip trading on that bar, so a crossover
        occurring exactly there would be missed. Delegate explicitly instead.
        """
        self.next()

    def next(self):
        """Enter on an upward crossover when flat; exit on a downward one."""
        if not self.position:
            if self.buy_singal[0] > 0:
                self.order = self.buy()
                print(f"{self.data0.datetime.date(0)},买入！价格为{self.data0.close[0]}")
        else:
            if self.sell_singal[0] > 0:
                self.order = self.sell()
                print(f"{self.data0.datetime.date(0)},卖出！价格为{self.data0.close[0]}")

    def stop(self):
        """Print the date of the last processed bar when the backtest ends."""
        print(f"stop___{self.datas[0].datetime.date(0)}")


if __name__ == '__main__':
    # Create the engine exactly once. stdstats=False disables backtrader's
    # default observers so only the ones added below are attached/plotted.
    # (Previously a first Cerebro received the cash setting and was then
    # replaced, so the run silently used the broker's default cash.)
    cerebro = bt.Cerebro(stdstats=False)
    cerebro.broker.set_cash(100000.00)  # initial cash for this engine's broker
    cerebro.addobserver(bt.observers.Broker)
    # cerebro.addobserver(bt.observers.Trades)
    cerebro.addobserver(bt.observers.BuySell)
    # cerebro.addobserver(bt.observers.DrawDown)
    cerebro.addobserver(bt.observers.Value)
    # cerebro.addobserver(bt.observers.TimeReturn)

    init_fund = cerebro.broker.getvalue()
    print(f'初始资金:{init_fund}')

    filename = 'sz000002.csv'
    print(filename)
    # Use the "date" column as a parsed DatetimeIndex, as PandasData expects.
    data = pd.read_csv(filename, index_col="date", parse_dates=True)
    # Alternative date handling:
    # data.index = pd.to_datetime(data.date)
    # data.drop(columns=["date"], inplace=True)

    # Restrict the backtest window to 2020-01-01 .. 2020-10-18.
    load_data = bt.feeds.PandasData(
        dataname=data,
        fromdate=datetime.datetime(2020, 1, 1),
        todate=datetime.datetime(2020, 10, 18),
    )
    cerebro.adddata(load_data)
    cerebro.addstrategy(MyStrategy)

    cerebro.run()
    end_fund = cerebro.broker.getvalue()
    print(f'期末资金:{end_fund}')

    # If plotting fails with
    #   "cannot import name 'warnings' from 'matplotlib.dates'",
    # the installed matplotlib is too new for backtrader's plotting; fix with:
    #   pip uninstall matplotlib
    #   pip install matplotlib==3.2.2
    cerebro.plot(style="candle")